{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 准备环境"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Use one of the following commands to start using PyAlink:\n",
      " - useLocalEnv(parallelism, flinkHome=None, config=None)\n",
      " - useRemoteEnv(host, port, parallelism, flinkHome=None, localIp=\"localhost\", config=None)\n",
      "Call resetEnv() to reset environment and switch to another.\n",
      "\n",
      "JVM listening on 127.0.0.1:50568\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "MLEnv(benv=JavaObject id=o2, btenv=JavaObject id=o5, senv=JavaObject id=o3, stenv=JavaObject id=o6)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Start a local PyAlink session for this notebook.\n",
    "from pyalink.alink import *\n",
    "\n",
    "# Reset any previously selected environment, then start a local one with\n",
    "# parallelism 1. The call is left as the last expression so the MLEnv is displayed.\n",
    "resetEnv()\n",
    "useLocalEnv(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column names shared by the cells below.\n",
    "LABEL_COL = \"rating5\"\n",
    "TEXT_COL = \"review_context\"\n",
    "VECTOR_COL = \"vec\"\n",
    "PRED_COL = \"pred\"\n",
    "PRED_DETAIL_COL = \"predDetail\"\n",
    "\n",
    "# Location and schema of the review-rating CSV file.\n",
    "URL = \"https://alink-release.oss-cn-beijing.aliyuncs.com/data-files/review_rating_train.csv\"\n",
    "SCHEMA_STR = \"review_id bigint, rating5 bigint, rating3 bigint, review_context string\"\n",
    "\n",
    "# Fields are delimited by the literal token \"_alink_\" and the quote character is set to None.\n",
    "source = (\n",
    "    CsvSourceBatchOp()\n",
    "    .setFilePath(URL)\n",
    "    .setSchemaStr(SCHEMA_STR)\n",
    "    .setFieldDelimiter(\"_alink_\")\n",
    "    .setQuoteChar(None)\n",
    ")\n",
    "\n",
    "# trainData is the splitter's main output (fraction set to 0.9);\n",
    "# testData is its side output 0.\n",
    "# NOTE(review): no seed is set on the split, so the metrics below may differ\n",
    "# between runs - confirm whether SplitBatchOp accepts a seed in the installed version.\n",
    "trainData = (\n",
    "    SplitBatchOp()\n",
    "    .setFraction(0.9)\n",
    "    .linkFrom(source)\n",
    ")\n",
    "testData = trainData.getSideOutput(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 特征工程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature stages, added in this order:\n",
    "#   1. Segment                - selected column TEXT_COL\n",
    "#   2. StopWordsRemover       - selected column TEXT_COL\n",
    "#   3. DocHashCountVectorizer - feature type WORD_COUNT, reads TEXT_COL, writes VECTOR_COL\n",
    "pipeline = (\n",
    "    Pipeline()\n",
    "    .add(Segment().setSelectedCol(TEXT_COL))\n",
    "    .add(StopWordsRemover().setSelectedCol(TEXT_COL))\n",
    "    .add(\n",
    "        DocHashCountVectorizer()\n",
    "        .setFeatureType(\"WORD_COUNT\")\n",
    "        .setSelectedCol(TEXT_COL)\n",
    "        .setOutputCol(VECTOR_COL)\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Naive Bayes text classifier: reads VECTOR_COL, learns LABEL_COL, and writes its\n",
    "# prediction to PRED_COL and the prediction detail to PRED_DETAIL_COL.\n",
    "naiveBayes = (\n",
    "    NaiveBayesTextClassifier()\n",
    "    .setVectorCol(VECTOR_COL)\n",
    "    .setLabelCol(LABEL_COL)\n",
    "    .setPredictionCol(PRED_COL)\n",
    "    .setPredictionDetailCol(PRED_DETAIL_COL)\n",
    ")\n",
    "\n",
    "# Build a new outer Pipeline from the feature pipeline and the classifier instead of\n",
    "# calling pipeline.add(naiveBayes). add() appends to the `pipeline` object itself, so\n",
    "# re-running this cell would keep stacking classifier stages onto it. With a fresh\n",
    "# outer Pipeline the cell can be re-run safely and `pipeline` stays features-only.\n",
    "model = Pipeline().add(pipeline).add(naiveBayes).fit(trainData)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据预测评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the fitted pipeline model to the test split.\n",
    "predict = model.transform(testData)\n",
    "\n",
    "# Configure the multi-class evaluator with the label column and the\n",
    "# prediction-detail column, then link it to the predictions and collect the metrics.\n",
    "evaluator = (\n",
    "    EvalMultiClassBatchOp()\n",
    "    .setLabelCol(LABEL_COL)\n",
    "    .setPredictionDetailCol(PRED_DETAIL_COL)\n",
    ")\n",
    "metrics = evaluator.linkFrom(predict).collectMetrics()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 打印评估结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ConfusionMatrix: [[4987, 327, 229, 204, 292], [28, 1223, 164, 147, 108], [1, 1, 269, 10, 11], [0, 0, 0, 10, 0], [0, 2, 1, 2, 83]]\n",
      "LabelArray: ['5', '4', '3', '2', '1']\n",
      "LogLoss: 2.330945631084851\n",
      "Accuracy: 0.8114582047166317\n",
      "Kappa: 0.6190950197563011\n",
      "MacroF1: 0.5123859853163818\n",
      "Label 1 Accuracy: 0.9486356340288925\n",
      "Label 1 Kappa: 0.27179135595030096\n",
      "Label 1 Precision: 0.9431818181818182\n"
     ]
    }
   ],
   "source": [
    "# Overall metrics.\n",
    "print(f\"ConfusionMatrix: {metrics.getConfusionMatrix()}\")\n",
    "print(f\"LabelArray: {metrics.getLabelArray()}\")\n",
    "print(f\"LogLoss: {metrics.getLogLoss()}\")\n",
    "print(f\"Accuracy: {metrics.getAccuracy()}\")\n",
    "print(f\"Kappa: {metrics.getKappa()}\")\n",
    "print(f\"MacroF1: {metrics.getMacroF1()}\")\n",
    "\n",
    "# Metrics for the single label \"1\".\n",
    "print(f\"Label 1 Accuracy: {metrics.getAccuracy('1')}\")\n",
    "print(f\"Label 1 Kappa: {metrics.getKappa('1')}\")\n",
    "print(f\"Label 1 Precision: {metrics.getPrecision('1')}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
